{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "069270c8-728b-4f9c-a1df-cfcc28c652a0",
   "metadata": {},
   "source": [
    "This is an example of a training and testing pipeline for the GemNet-OC model.  \n",
    "The same tasks can be run with the pre-defined configs from the repository root:\n",
    "```bash\n",
    "python run.py --config-name gemnet-oc.yaml\n",
    "python run.py --config-name gemnet-oc_test.yaml\n",
    "```\n",
    "For a detailed description, please refer to the [README](../nablaDFT/README.md).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9274655-f8d0-47f5-9577-076c85a52ca6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "import nablaDFT\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torchmetrics\n",
    "from nablaDFT import model_registry\n",
    "from nablaDFT.dataset import PyGNablaDFTDataModule, dataset_registry\n",
    "from nablaDFT.gemnet_oc import GemNetOC, GemNetOCLightning\n",
    "from nablaDFT.utils import seed_everything\n",
    "from omegaconf import OmegaConf\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from torch_ema import ExponentialMovingAverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a75e4459-4436-4877-b510-a413c0ae1f00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment settings, collected in one cell so they can be tuned in one place.\n",
    "dataset_name = \"dataset_train_tiny\"  # Name of the training dataset split\n",
    "datapath = \"database\"  # Path to the selected dataset (first argument of the data modules)\n",
    "nepochs = 200  # Maximum number of epochs to train for (passed as max_epochs)\n",
    "seed = 1799  # Random seed number for reproducibility\n",
    "batch_size = 8  # Batch size, used for both the training and the test data module\n",
    "num_workers = 2  # Dataloader's num_workers\n",
    "train_ratio = 0.9  # Fraction of the dataset used for training (passed as train_size)\n",
    "val_ratio = 0.1  # Fraction of the dataset used for validation (passed as val_size)\n",
    "devices = 1  # Number of GPU/TPU/CPU devices to use for training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd7a31fc-c15d-406a-9901-24728c77de96",
   "metadata": {},
   "source": [
    "This example uses the *tiny* training split of the dataset.  \n",
    "All available dataset splits can be listed with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f2ab09e-9b9d-41d7-93cf-ccd43e1aff89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the dataset splits the registry exposes for the \"energy\" argument\n",
    "# (presumably the task type - confirm in dataset_registry).\n",
    "dataset_registry.list_datasets(\"energy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d531ffe3-2961-47d6-b174-c11e7a8119a6",
   "metadata": {},
   "source": [
    "## Downloading dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af0b26e2-3f98-4d00-8703-e361dba2e4ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed before constructing the datamodule; the same seed is also passed to it below.\n",
    "seed_everything(seed)\n",
    "\n",
    "# Train/validation data module; train_size and val_size take the ratios from the config cell.\n",
    "datamodule = PyGNablaDFTDataModule(\n",
    "    datapath,\n",
    "    dataset_name,\n",
    "    train_size=train_ratio,\n",
    "    val_size=val_ratio,\n",
    "    seed=seed,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "# Prepare the datamodule for fitting (presumably this is where the dataset is\n",
    "# downloaded, as the section title suggests - confirm).\n",
    "datamodule.setup(stage=\"fit\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb759b67-cc6c-47d4-8c29-fd401d85a3f0",
   "metadata": {},
   "source": [
    "## Initialize model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ab300a5-8af5-4458-8be3-c5dfc7f53ee0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GemNet-OC hyperparameters; this dict is unpacked as keyword arguments\n",
    "# into GemNetOC(**model_cfg) in a later cell.\n",
    "model_cfg = {\n",
    "    \"num_targets\": 1,\n",
    "    \"num_spherical\": 7,\n",
    "    \"num_radial\": 128,\n",
    "    \"num_blocks\": 4,\n",
    "    \"emb_size_atom\": 256,\n",
    "    \"emb_size_edge\": 512,\n",
    "    \"emb_size_trip_in\": 64,\n",
    "    \"emb_size_trip_out\": 64,\n",
    "    \"emb_size_quad_in\": 32,\n",
    "    \"emb_size_quad_out\": 32,\n",
    "    \"emb_size_aint_in\": 64,\n",
    "    \"emb_size_aint_out\": 64,\n",
    "    \"emb_size_rbf\": 16,\n",
    "    \"emb_size_cbf\": 16,\n",
    "    \"emb_size_sbf\": 32,\n",
    "    \"num_before_skip\": 2,\n",
    "    \"num_after_skip\": 2,\n",
    "    \"num_concat\": 1,\n",
    "    \"num_atom\": 3,\n",
    "    \"num_output_afteratom\": 3,\n",
    "    \"num_atom_emb_layers\": 0,\n",
    "    \"num_global_out_layers\": 2,\n",
    "    \"regress_forces\": True,\n",
    "    \"direct_forces\": True,\n",
    "    \"use_pbc\": False,\n",
    "    \"scale_backprop_forces\": False,\n",
    "    \"cutoff\": 12.0,\n",
    "    \"cutoff_qint\": 12.0,\n",
    "    \"cutoff_aeaint\": 12.0,\n",
    "    \"cutoff_aint\": 12.0,\n",
    "    \"max_neighbors\": 30,\n",
    "    \"max_neighbors_qint\": 8,\n",
    "    \"max_neighbors_aeaint\": 20,\n",
    "    \"max_neighbors_aint\": 1000,\n",
    "    \"enforce_max_neighbors_strictly\": True,\n",
    "    \"rbf\": {\"name\": \"gaussian\"},\n",
    "    \"rbf_spherical\": None,\n",
    "    \"envelope\": {\"name\": \"polynomial\", \"exponent\": 5},\n",
    "    \"cbf\": {\"name\": \"spherical_harmonics\"},\n",
    "    \"sbf\": {\"name\": \"legendre_outer\"},\n",
    "    \"extensive\": True,\n",
    "    \"forces_coupled\": True,\n",
    "    \"output_init\": \"HeOrthogonal\",\n",
    "    \"activation\": \"silu\",\n",
    "    \"scale_file\": None,\n",
    "    \"quad_interaction\": True,\n",
    "    \"atom_edge_interaction\": True,\n",
    "    \"edge_atom_interaction\": True,\n",
    "    \"atom_interaction\": True,\n",
    "    \"scale_basis\": True,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "767f769e-6ae0-4250-b2c2-5e572df33c24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword arguments pre-bound to torch.optim.AdamW with functools.partial in the next cell.\n",
    "opt_args = {\"amsgrad\": True, \"betas\": [0.9, 0.95], \"lr\": 1e-3, \"weight_decay\": 0}\n",
    "# Keyword arguments pre-bound to torch.optim.lr_scheduler.ReduceLROnPlateau in the next cell.\n",
    "lr_args = {\"factor\": 0.8, \"patience\": 10}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d608fe91-490c-4197-adc8-da92c41717e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backbone network built from the hyperparameter dict.\n",
    "net = GemNetOC(**model_cfg)\n",
    "\n",
    "# partial() only pre-binds the keyword arguments; the optimizer, scheduler and EMA\n",
    "# are instantiated during Lightning module creation.\n",
    "optimizer = partial(torch.optim.AdamW, **opt_args)\n",
    "lr_scheduler = partial(torch.optim.lr_scheduler.ReduceLROnPlateau, **lr_args)\n",
    "\n",
    "# One loss, one weight and one metric per predicted target.\n",
    "losses = {\n",
    "    \"energy\": torch.nn.L1Loss(),\n",
    "    \"forces\": nablaDFT.gemnet_oc.loss.L2Loss(),\n",
    "}\n",
    "losses_coefs = {\"energy\": 1, \"forces\": 100}\n",
    "metric = torchmetrics.MultitaskWrapper(\n",
    "    task_metrics={\n",
    "        \"energy\": torchmetrics.MeanAbsoluteError(),\n",
    "        \"forces\": torchmetrics.MeanAbsoluteError(),\n",
    "    }\n",
    ")\n",
    "\n",
    "# Arguments stay positional, in the same order as before.\n",
    "model = GemNetOCLightning(\"GemNet-OC\", net, optimizer, lr_scheduler, losses, metric, losses_coefs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d124f1-a598-4f9a-abb4-e50196286e00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the best checkpoint by validation loss, plus the most recent one.\n",
    "model_checkpoint = pl.callbacks.ModelCheckpoint(\n",
    "    monitor=\"val/loss\",\n",
    "    mode=\"min\",\n",
    "    save_top_k=1,\n",
    "    save_last=True,\n",
    "    dirpath=\"./checkpoints\",\n",
    "    # The placeholder must use the logged metric name \"val/loss\": an unknown name such\n",
    "    # as \"val_loss\" is rendered as 0. Labels are written by hand and auto-insertion is\n",
    "    # disabled because the \"/\" in the metric name would otherwise create extra folders.\n",
    "    filename=\"GemNet-OC_epoch={epoch:03d}_val_loss={val/loss:.4f}\",\n",
    "    auto_insert_metric_name=False,\n",
    ")\n",
    "# Stop once val/loss has failed to improve by min_delta for `patience` validation checks.\n",
    "early_stopping = pl.callbacks.EarlyStopping(\n",
    "    monitor=\"val/loss\", min_delta=1e-4, patience=50, mode=\"min\", check_on_train_epoch_end=False\n",
    ")\n",
    "callbacks = [model_checkpoint, early_stopping]\n",
    "logger = pl.loggers.TensorBoardLogger(save_dir=\"./tensorboard_logs\")\n",
    "\n",
    "trainer = pl.Trainer(\n",
    "    callbacks=callbacks,\n",
    "    logger=logger,\n",
    "    accelerator=\"gpu\",\n",
    "    devices=devices,  # honour the device count chosen in the settings cell\n",
    "    max_epochs=nepochs,\n",
    "    gradient_clip_algorithm=\"norm\",\n",
    "    gradient_clip_val=5.0,\n",
    ")\n",
    "\n",
    "trainer.fit(model=model, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "402a0565-fd69-4ab6-86bb-6181428605db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the best checkpoint recorded by the trainer's checkpoint callback.\n",
    "# The test cell below loads a pretrained model from the registry instead of this path.\n",
    "ckpt_path = trainer.checkpoint_callback.best_model_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb3330af-925b-4b15-9f7c-308039d075c9",
   "metadata": {},
   "source": [
    "## Initializing the testing procedure and computing the metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e92c1258-fe31-4b3b-abca-a87af51dccb2",
   "metadata": {},
   "source": [
    "For testing, we will use a pretrained model obtained from the *model_registry* object.  \n",
    "You can list all available pretrained models with:  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb1ca3e1-0976-4871-a735-f5e3b46e21c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the pretrained models known to model_registry.\n",
    "model_registry.list_models()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "811609de-9c08-4c4b-ba0e-4475ce44053f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the pretrained \"GemNet-OC_train_tiny\" model from the registry.\n",
    "# NOTE: this rebinds `model`, replacing the module trained in the cells above.\n",
    "model = model_registry.get_pretrained_model(\"lightning\", \"GemNet-OC_train_tiny\")\n",
    "# Test data module: same datapath, batch_size and num_workers as training, different dataset name.\n",
    "datamodule_test = PyGNablaDFTDataModule(\n",
    "    datapath,\n",
    "    \"dataset_test_conformations_tiny\",\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    ")\n",
    "# Reuses the Trainer instance created in the training cell.\n",
    "trainer.test(model=model, datamodule=datamodule_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a9d5658-5074-4faa-a9e3-aa9a7a44628c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
